{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_2_lif_neuron.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HzIQBw28NL8h"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "# snnTorch - The Leaky Integrate and Fire Neuron\n",
        "## Tutorial 2\n",
        "### By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_2_lif_neuron.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H56eQUdm3IuT"
      },
      "source": [
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ep_Qv7kzNOz6"
      },
      "source": [
        "# Introduction\n",
        "In this tutorial, you will:\n",
        "* Learn the fundamentals of the leaky integrate-and-fire (LIF) neuron model\n",
        "* Use snnTorch to implement a first order LIF neuron\n",
        "\n",
        "Install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SPQITvDuNNJg"
      },
      "outputs": [],
      "source": [
        "!pip install snntorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "31NkMZnxBFZ4"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "import snntorch as snn\n",
        "from snntorch import spikeplot as splt\n",
        "from snntorch import spikegen\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2kBGXe5K_xWh"
      },
      "outputs": [],
      "source": [
        "#@title Plotting Settings\n",
        "def plot_mem(mem, title=False):\n",
        "  if title:\n",
        "    plt.title(title)\n",
        "  plt.plot(mem)\n",
        "  plt.xlabel(\"Time step\")\n",
        "  plt.ylabel(\"Membrane Potential\")\n",
        "  plt.xlim([0, 50])\n",
        "  plt.ylim([0, 1])\n",
        "  plt.show()\n",
        "\n",
        "def plot_step_current_response(cur_in, mem_rec, vline1):\n",
        "  fig, ax = plt.subplots(2, figsize=(8,6),sharex=True)\n",
        "\n",
        "  # Plot input current\n",
        "  ax[0].plot(cur_in, c=\"tab:orange\")\n",
        "  ax[0].set_ylim([0, 0.2])\n",
        "  ax[0].set_ylabel(\"Input Current ($I_{in}$)\")\n",
        "  ax[0].set_title(\"Lapicque's Neuron Model With Step Input\")\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem_rec)\n",
        "  ax[1].set_ylim([0, 0.6]) \n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "\n",
        "  if vline1:\n",
        "    ax[1].axvline(x=vline1, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def plot_current_pulse_response(cur_in, mem_rec, title, vline1=False, vline2=False, ylim_max1=False):\n",
        "\n",
        "  fig, ax = plt.subplots(2, figsize=(8,6),sharex=True)\n",
        "\n",
        "  # Plot input current\n",
        "  ax[0].plot(cur_in, c=\"tab:orange\")\n",
        "  if not ylim_max1:\n",
        "    ax[0].set_ylim([0, 0.2])\n",
        "  else:\n",
        "    ax[0].set_ylim([0, ylim_max1])\n",
        "  ax[0].set_ylabel(\"Input Current ($I_{in}$)\")\n",
        "  ax[0].set_title(title)\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem_rec)\n",
        "  ax[1].set_ylim([0, 1])\n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "\n",
        "  if vline1:\n",
        "    ax[1].axvline(x=vline1, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  if vline2:\n",
        "    ax[1].axvline(x=vline2, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  plt.show()\n",
        "\n",
        "def compare_plots(cur1, cur2, cur3, mem1, mem2, mem3, vline1, vline2, vline3, vline4, title):\n",
        "  # Generate Plots\n",
        "  fig, ax = plt.subplots(2, figsize=(8,6),sharex=True)\n",
        "\n",
        "  # Plot input current\n",
        "  ax[0].plot(cur1)\n",
        "  ax[0].plot(cur2)\n",
        "  ax[0].plot(cur3)\n",
        "  ax[0].set_ylim([0, 0.2])\n",
        "  ax[0].set_ylabel(\"Input Current ($I_{in}$)\")\n",
        "  ax[0].set_title(title)\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem1)\n",
        "  ax[1].plot(mem2)\n",
        "  ax[1].plot(mem3)\n",
        "  ax[1].set_ylim([0, 1])\n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "\n",
        "  ax[1].axvline(x=vline1, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  ax[1].axvline(x=vline2, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  ax[1].axvline(x=vline3, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  ax[1].axvline(x=vline4, ymin=0, ymax=2.2, alpha = 0.25, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  plt.show()\n",
        "\n",
        "def plot_cur_mem_spk(cur, mem, spk, thr_line=False, vline=False, title=False, ylim_max2=1.25):\n",
        "  # Generate Plots\n",
        "  fig, ax = plt.subplots(3, figsize=(8,6), sharex=True, \n",
        "                        gridspec_kw = {'height_ratios': [1, 1, 0.4]})\n",
        "\n",
        "  # Plot input current\n",
        "  ax[0].plot(cur, c=\"tab:orange\")\n",
        "  ax[0].set_ylim([0, 0.4])\n",
        "  ax[0].set_xlim([0, 200])\n",
        "  ax[0].set_ylabel(\"Input Current ($I_{in}$)\")\n",
        "  if title:\n",
        "    ax[0].set_title(title)\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem)\n",
        "  ax[1].set_ylim([0, ylim_max2]) \n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "  if thr_line:\n",
        "    ax[1].axhline(y=thr_line, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  # Plot output spike using spikeplot\n",
        "  splt.raster(spk, ax[2], s=400, c=\"black\", marker=\"|\")\n",
        "  if vline:\n",
        "    ax[2].axvline(x=vline, ymin=0, ymax=6.75, alpha = 0.15, linestyle=\"dashed\", c=\"black\", linewidth=2, zorder=0, clip_on=False)\n",
        "  plt.ylabel(\"Output spikes\")\n",
        "  plt.yticks([]) \n",
        "\n",
        "  plt.show()\n",
        "\n",
        "def plot_spk_mem_spk(spk_in, mem, spk_out, title):\n",
        "  # Generate Plots\n",
        "  fig, ax = plt.subplots(3, figsize=(8,6), sharex=True, \n",
        "                        gridspec_kw = {'height_ratios': [0.4, 1, 0.4]})\n",
        "\n",
        "  # Plot input spikes\n",
        "  splt.raster(spk_in, ax[0], s=400, c=\"black\", marker=\"|\")\n",
        "  ax[0].set_ylabel(\"Input Spikes\")\n",
        "  ax[0].set_title(title)\n",
        "  ax[0].set_yticks([])\n",
        "\n",
        "  # Plot membrane potential\n",
        "  ax[1].plot(mem)\n",
        "  ax[1].set_ylim([0, 1])\n",
        "  ax[1].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "  ax[1].axhline(y=0.5, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
        "  plt.xlabel(\"Time step\")\n",
        "\n",
        "  # Plot output spike using spikeplot\n",
        "  splt.raster(spk_out, ax[2], s=400, c=\"black\", marker=\"|\")\n",
        "  plt.ylabel(\"Output spikes\")\n",
        "  plt.yticks([]) \n",
        "\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def plot_reset_comparison(spk_in, mem_rec, spk_rec, mem_rec0, spk_rec0):\n",
        "  # Generate Plots to Compare Reset Mechanisms\n",
        "  fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10,6), sharex=True, \n",
        "                        gridspec_kw = {'height_ratios': [0.4, 1, 0.4], 'wspace':0.05})\n",
        "\n",
        "  # Reset by Subtraction: input spikes\n",
        "  splt.raster(spk_in, ax[0][0], s=400, c=\"black\", marker=\"|\")\n",
        "  ax[0][0].set_ylabel(\"Input Spikes\")\n",
        "  ax[0][0].set_title(\"Reset by Subtraction\")\n",
        "  ax[0][0].set_yticks([])\n",
        "\n",
        "  # Reset by Subtraction: membrane potential \n",
        "  ax[1][0].plot(mem_rec)\n",
        "  ax[1][0].set_ylim([0, 0.7])\n",
        "  ax[1][0].set_ylabel(\"Membrane Potential ($U_{mem}$)\")\n",
        "  ax[1][0].axhline(y=0.5, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
        "\n",
        "  # Reset by Subtraction: output spikes\n",
        "  splt.raster(spk_rec, ax[2][0], s=400, c=\"black\", marker=\"|\")\n",
        "  ax[2][0].set_yticks([])\n",
        "  ax[2][0].set_xlabel(\"Time step\")\n",
        "  ax[2][0].set_ylabel(\"Output Spikes\")\n",
        "\n",
        "  # Reset to Zero: input spikes\n",
        "  splt.raster(spk_in, ax[0][1], s=400, c=\"black\", marker=\"|\")\n",
        "  ax[0][1].set_title(\"Reset to Zero\")\n",
        "  ax[0][1].set_yticks([])\n",
        "\n",
        "  # Reset to Zero: membrane potential\n",
        "  ax[1][1].plot(mem_rec0)\n",
        "  ax[1][1].set_ylim([0, 0.7])\n",
        "  ax[1][1].axhline(y=0.5, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
        "  ax[1][1].set_yticks([])\n",
        "  ax[2][1].set_xlabel(\"Time step\")\n",
        "\n",
        "  # Reset to Zero: output spikes\n",
        "  splt.raster(spk_rec0, ax[2][1], s=400, c=\"black\", marker=\"|\")\n",
        "  ax[2][1].set_yticks([])\n",
        "\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xmTt5dvyXNNy"
      },
      "source": [
        "# 1. The Spectrum of Neuron Models\n",
        "A wide variety of neuron models exist, ranging from biophysically accurate models (e.g., the Hodgkin-Huxley models) to the extremely simple artificial neuron that pervades all facets of modern deep learning.\n",
        "\n",
        "**Hodgkin-Huxley Neuron Models**$-$While biophysical models can reproduce electrophysiological results with a high degree of accuracy, their complexity makes them difficult to use at present.\n",
        "\n",
        "**Artificial Neuron Model**$-$On the other end of the spectrum is the artificial neuron. The inputs are multiplied by their corresponding weights and passed through an activation function. This simplification has enabled deep learning researchers to perform incredible feats in computer vision, natural language processing, and many other machine learning-domain tasks.\n",
        "\n",
        "**Leaky Integrate-and-Fire Neuron Models**$-$Somewhere in the middle of the divide lies the leaky integrate-and-fire (LIF) neuron model. It takes the sum of weighted inputs, much like the artificial neuron. But rather than passing it directly to an activation function, it will integrate the input over time with a leakage, much like an RC circuit. If the integrated value exceeds a threshold, then the LIF neuron will emit a voltage spike. The LIF neuron abstracts away the shape and profile of the output spike; it is simply treated as a discrete event. As a result, information is not stored within the spike, but rather in the timing (or frequency) of spikes. Simple spiking neuron models have produced much insight into the neural code, memory, network dynamics, and more recently, deep learning. The LIF neuron sits in the sweet spot between biological plausibility and practicality.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_1_neuronmodels.png?raw=true' width=\"1000\">\n",
        "</center>\n",
        "\n",
        "\n",
        "The different versions of the LIF model each have their own dynamics and use-cases. snnTorch currently supports the following LIF neurons:\n",
        "* Lapicque's RC model: ``snntorch.Lapicque``\n",
        "* 1st-order model: ``snntorch.Leaky`` \n",
        "* Synaptic Conductance-based neuron model: ``snntorch.Synaptic``\n",
        "* Recurrent 1st-order model: ``snntorch.RLeaky``\n",
        "* Recurrent Synaptic Conductance-based neuron model: ``snntorch.RSynaptic``\n",
        "* Alpha neuron Model: ``snntorch.Alpha``\n",
        "\n",
        "Several other non-LIF spiking neurons are also available. This tutorial focuses on the first of these models. This will be used to build towards the other models in [subsequent tutorials](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nea8oBorr_KZ"
      },
      "source": [
        "# 2. The Leaky Integrate-and-Fire Neuron Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YKsrN5feQ2Dz"
      },
      "source": [
        "## 2.1 Spiking Neurons: Intuition"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YGStnfjzsGKb"
      },
      "source": [
        "In our brains, a neuron might be connected to 1,000 $-$ 10,000 other neurons. If one neuron spikes, all downstream neurons might feel it. But what determines whether a neuron spikes in the first place? Experiments over the past century demonstrate that if a neuron experiences *sufficient* stimulus at its input, then it might become excited and fire its own spike.\n",
        "\n",
        "Where does this stimulus come from? It could be from:\n",
        "* the sensory periphery, \n",
        "* an invasive electrode artificially stimulating the neuron, or in most cases,\n",
        "* from other pre-synaptic neurons. \n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_2_intuition.png?raw=true' width=\"500\">\n",
        "</center>\n",
        "\n",
        "Given that these spikes are very short bursts of electrical activity, it is quite unlikely for all input spikes to arrive at the neuron body in precise unison. This indicates the presence of temporal dynamics that 'sustain' the input spikes, kind of like a delay.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l44jI7A2ReB_"
      },
      "source": [
        "## 2.2 The Passive Membrane"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Od5tfXv6_wcq"
      },
      "source": [
        "Like all cells, a neuron is surrounded by a thin membrane. This membrane is a lipid bilayer that insulates the conductive saline solution within the neuron from the extracellular medium. Electrically, the two conductive solutions separated by an insulator act as a capacitor. \n",
        "\n",
        "Another function of this membrane is to control what goes in and out of the cell (e.g., ions such as Na$^+$). The membrane is usually impermeable to ions, which blocks them from entering and exiting the neuron body. But there are specific channels in the membrane that are triggered to open by injecting current into the neuron. This charge movement is electrically modelled by a resistor.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_3_passivemembrane.png?raw=true' width=\"450\">\n",
        "</center>\n",
        "\n",
        "The following block will derive the behaviour of a LIF neuron from scratch. If you'd prefer to skip the math, then feel free to scroll on by; we'll take a more hands-on approach to understanding the LIF neuron dynamics after the derivation. \n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wT_S4b4HL2ir"
      },
      "source": [
        "---\n",
        "**Optional: Derivation of LIF Neuron Model**\n",
        "\n",
        "Now say some arbitrary time-varying current $I_{\\rm in}(t)$ is injected into the neuron, be it via electrical stimulation or from other neurons. The total current in the circuit is conserved, so:\n",
        "\n",
        "$$I_{\\rm in}(t) = I_{R} + I_{C}$$\n",
        "\n",
        "From Ohm's Law, the membrane potential measured between the inside and outside of the neuron $U_{\\rm mem}$ is proportional to the current through the resistor:\n",
        "\n",
        "$$I_{R}(t) = \\frac{U_{\\rm mem}(t)}{R}$$\n",
        "\n",
        "The capacitance is a proportionality constant between the charge stored on the capacitor $Q$ and $U_{\\rm mem}(t)$:\n",
        "\n",
        "\n",
        "$$Q = CU_{\\rm mem}(t)$$\n",
        "\n",
        "The rate of change of charge gives the capacitive current:\n",
        "\n",
        "$$\\frac{dQ}{dt}=I_C(t) = C\\frac{dU_{\\rm mem}(t)}{dt}$$\n",
        "\n",
        "Therefore:\n",
        "\n",
        "$$I_{\\rm in}(t) = \\frac{U_{\\rm mem}(t)}{R} + C\\frac{dU_{\\rm mem}(t)}{dt}$$\n",
        "\n",
        "$$\\implies RC \\frac{dU_{\\rm mem}(t)}{dt} = -U_{\\rm mem}(t) + RI_{\\rm in}(t)$$\n",
        "\n",
        "The right hand side of the equation is of units **\\[Voltage]**. On the left hand side of the equation, the term $\\frac{dU_{\\rm mem}(t)}{dt}$ is of units **\\[Voltage/Time]**. For the left hand side to match the right hand side (i.e., voltage), $RC$ must be of unit **\\[Time]**. We refer to $\\tau = RC$ as the time constant of the circuit:\n",
        "\n",
        "$$ \\tau \\frac{dU_{\\rm mem}(t)}{dt} = -U_{\\rm mem}(t) + RI_{\\rm in}(t)$$\n",
        "\n",
        "The passive membrane is therefore described by a linear differential equation.\n",
        "\n",
        "When the derivative of a function has the same form as the function itself, i.e., $\\frac{dU_{\\rm mem}(t)}{dt} \\propto U_{\\rm mem}(t)$, the solution is exponential with a time constant $\\tau$.\n",
        "\n",
        "Say the neuron starts at some value $U_{0}$ with no further input, i.e., $I_{\\rm in}(t)=0$. The solution of the linear differential equation is:\n",
        "\n",
        "$$U_{\\rm mem}(t) = U_0e^{-\\frac{t}{\\tau}}$$\n",
        "\n",
        "The general solution is shown below.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_4_RCmembrane.png?raw=true' width=\"500\">\n",
        "</center>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a3v_nLcCW0TO"
      },
      "source": [
        "**Optional: Forward Euler Method for Solving the LIF Neuron Model**\n",
        "\n",
        "We managed to find the analytical solution to the LIF neuron, but it is unclear how this might be useful in a neural network. This time, let's instead use the forward Euler method to solve the previous linear ordinary differential equation (ODE). This approach might seem arduous, but it gives us a discrete, recurrent representation of the LIF neuron. Once we reach this solution, it can be applied directly to a neural network. As before, the linear ODE describing the RC circuit is:\n",
        "\n",
        "$$\\tau \\frac{dU(t)}{dt} = -U(t) + RI_{\\rm in}(t)$$\n",
        "\n",
        "The subscript from $U(t)$ is omitted for simplicity.\n",
        "\n",
        "First, let's approximate this derivative without taking the limit $\\Delta t \\rightarrow 0$:\n",
        "\n",
        "$$\\tau \\frac{U(t+\\Delta t)-U(t)}{\\Delta t} = -U(t) + RI_{\\rm in}(t)$$\n",
        "\n",
        "For a small enough $\\Delta t$, this gives a good enough approximation of continuous-time integration. Isolating the membrane at the following time step gives:\n",
        "\n",
        "$$U(t+\\Delta t) = U(t) + \\frac{\\Delta t}{\\tau}\\big(-U(t) + RI_{\\rm in}(t)\\big)$$\n",
        "\n",
        "The following function represents this equation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kFJTqNLQZyam"
      },
      "outputs": [],
      "source": [
        "def leaky_integrate_neuron(U, time_step=1e-3, I=0, R=5e7, C=1e-10):\n",
        "  tau = R*C\n",
        "  U = U + (time_step/tau)*(-U + I*R)\n",
        "  return U"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uBk0fvH8aQkU"
      },
      "source": [
        "The default values are set to $R=50 M\\Omega$ and $C=100pF$ (i.e., $\\tau=5ms$). These are quite realistic with respect to biological neurons.\n",
        "\n",
        "Now loop through this function, iterating one time step at a time. \n",
        "The membrane potential is initialized at $U=0.9 V$, with the assumption that there is no injected input current, $I_{\\rm in}=0 A$.\n",
        "The simulation is performed with a millisecond precision $\\Delta t=1\\times 10^{-3}$s. "
      ]
    },
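    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check by hand (a worked example using the default values above, not part of the original derivation), set $I_{\\rm in}=0$ in the forward Euler update. Each step then scales the membrane potential by a constant factor:\n",
        "\n",
        "$$U(t+\\Delta t) = \\Big(1 - \\frac{\\Delta t}{\\tau}\\Big)U(t) = \\Big(1 - \\frac{1~\\text{ms}}{5~\\text{ms}}\\Big)U(t) = 0.8\\,U(t)$$\n",
        "\n",
        "Starting from $U=0.9$ V, the first few values are $0.9 \\rightarrow 0.72 \\rightarrow 0.576$ V. The exact solution decays by $e^{-\\Delta t/\\tau} = e^{-0.2} \\approx 0.819$ per step, so forward Euler decays slightly faster at this step size."
      ]
    },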
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ziFpaBigah5D"
      },
      "outputs": [],
      "source": [
        "num_steps = 100\n",
        "U = 0.9\n",
        "U_trace = []  # keeps a record of U for plotting\n",
        "\n",
        "for step in range(num_steps):\n",
        "  U_trace.append(U)\n",
        "  U = leaky_integrate_neuron(U)  # solve next step of U\n",
        "\n",
        "plot_mem(U_trace, \"Leaky Neuron Model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vd41ihCAfAMr"
      },
      "source": [
        "This exponential decay seems to match what we expected! "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTXS_aSmRs-3"
      },
      "source": [
        "# 3. Lapicque's LIF Neuron Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jt2q6tZiWkgT"
      },
      "source": [
        "This similarity between nerve membranes and RC circuits was observed by [Louis Lapicque in 1907](https://core.ac.uk/download/pdf/21172797.pdf). He stimulated the nerve fiber of a frog with a brief electrical pulse, and found that neuron membranes could be approximated as a capacitor with a leakage. We pay homage to his findings by naming the basic LIF neuron model in snnTorch after him. \n",
        "\n",
        "Most of the concepts in Lapicque's model carry forward to other LIF neuron models. Now it's time to simulate this neuron using snnTorch."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WtveAGG0zE0n"
      },
      "source": [
        "## 3.1 Lapicque: Without Stimulus"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qi5EnND98cz3"
      },
      "source": [
        "Instantiate Lapicque's neuron using the following line of code. R & C are changed to simpler values, while keeping the previous time constant of $\\tau=5\\times10^{-3}$s."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wI2Nahsg8d-t"
      },
      "outputs": [],
      "source": [
        "time_step = 1e-3\n",
        "R = 5\n",
        "C = 1e-3\n",
        "\n",
        "# leaky integrate and fire neuron, tau=5e-3\n",
        "lif1 = snn.Lapicque(R=R, C=C, time_step=time_step)"
      ]
    },
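    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick check that both parameter choices give the same time constant (a worked calculation, using only the values quoted in this tutorial):\n",
        "\n",
        "$$\\tau = RC = (5\\times10^{7}~\\Omega)(1\\times10^{-10}~\\text{F}) = 5\\times10^{-3}~\\text{s}, \\qquad \\tau = (5~\\Omega)(1\\times10^{-3}~\\text{F}) = 5\\times10^{-3}~\\text{s}$$\n",
        "\n",
        "With `time_step = 1e-3`, the ratio $\\Delta t/\\tau = 0.2$ is unchanged, so the decay should look the same as in the hand-written example."
      ]
    },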
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kuytxXet8lc9"
      },
      "source": [
        "The neuron model is now stored in `lif1`. To use this neuron: \n",
        "\n",
        "**Inputs**\n",
        "* `cur_in`: each element of $I_{\\rm in}$ is sequentially passed as an input (0 for now)\n",
        "* `mem`: the membrane potential, previously $U[t]$, is also passed as input. Initialize it arbitrarily as $U[0] = 0.9~V$.\n",
        "\n",
        "**Outputs**\n",
        "* `spk_out`: output spike $S_{\\rm out}[t+\\Delta t]$ at the next time step ('1' if there is a spike; '0' if there is no spike)\n",
        "* `mem`: membrane potential $U_{\\rm mem}[t+\\Delta t]$ at the next time step\n",
        "\n",
        "These all need to be of type `torch.Tensor`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M_0O1Q3Y9KgM"
      },
      "outputs": [],
      "source": [
        "# Initialize membrane, input, and output\n",
        "mem = torch.ones(1) * 0.9  # U=0.9 at t=0\n",
        "cur_in = torch.zeros(num_steps)  # I=0 for all t \n",
        "spk_out = torch.zeros(1)  # initialize output spikes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_CEuNHsN8-67"
      },
      "source": [
        "These values are only for the initial time step $t=0$. \n",
        "To analyze the evolution of `mem` over time, create a list `mem_rec` to record these values at every time step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CUCxQBTQ9d_P"
      },
      "outputs": [],
      "source": [
        "# A list to store recordings of membrane potential\n",
        "mem_rec = [mem]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i4TxVFaV9uf6"
      },
      "source": [
        "Now it's time to run a simulation! At each time step, `mem` is updated and stored in `mem_rec`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j6GeMMiU8kc4"
      },
      "outputs": [],
      "source": [
        "# pass updated value of mem and cur_in[step]=0 at every time step\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in[step], mem)\n",
        "\n",
        "  # Store recordings of membrane potential\n",
        "  mem_rec.append(mem)\n",
        "\n",
        "# crunch the list of tensors into one tensor\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "\n",
        "plot_mem(mem_rec, \"Lapicque's Neuron Model Without Stimulus\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9k0PKng2JYcx"
      },
      "source": [
        "The membrane potential decays over time in the absence of any input stimuli. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B6GZUEFgJ0l0"
      },
      "source": [
        "## 3.2 Lapicque: Step Input\n",
        "\n",
        "Now apply a step current $I_{\\rm in}(t)$ that switches on at $t=t_0$. Given the linear first-order differential equation:\n",
        "\n",
        "$$ \\tau \\frac{dU_{\\rm mem}}{dt} = -U_{\\rm mem} + RI_{\\rm in}(t),$$\n",
        "\n",
        "the general solution is:\n",
        "\n",
        "$$U_{\\rm mem}=I_{\\rm in}(t)R + [U_0 - I_{\\rm in}(t)R]e^{-\\frac{t}{\\tau}}$$\n",
        "\n",
        "If the membrane potential is initialized to $U_{\\rm mem}(t=0) = 0 V$, then:\n",
        "\n",
        "$$U_{\\rm mem}(t)=I_{\\rm in}(t)R [1 - e^{-\\frac{t}{\\tau}}]$$\n",
        "\n",
        "Based on this explicit time-dependent form, we expect $U_{\\rm mem}$ to relax exponentially towards $I_{\\rm in}R$. Let's visualize what this looks like by switching on a current step of $I_{\\rm in}=100mA$ at $t_0 = 10ms$."
      ]
    },
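    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To know what to look for in the plot, here is a worked estimate from the continuous-time solution. It assumes the neuron instantiated earlier ($R=5~\\Omega$, $\\tau=5$ ms) and measures time from the onset of the step, $t_0$. The discrete simulation may differ slightly from these values.\n",
        "\n",
        "$$U_{\\rm mem}(t\\rightarrow\\infty) = I_{\\rm in}R = 0.1~\\text{A} \\times 5~\\Omega = 0.5~\\text{V}$$\n",
        "\n",
        "$$U_{\\rm mem}(t_0+\\tau) = I_{\\rm in}R\\,[1 - e^{-1}] \\approx 0.5 \\times 0.632 \\approx 0.316~\\text{V}$$\n",
        "\n",
        "So the membrane potential should cover roughly 63% of the distance to its final value within one time constant (5 time steps) of the step switching on."
      ]
    },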
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kD7cLKSuLu9n"
      },
      "outputs": [],
      "source": [
        "# Initialize input current pulse\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0)  # input current turns on at t=10\n",
        "\n",
        "# Initialize membrane, output and recordings\n",
        "mem = torch.zeros(1)  # membrane potential of 0 at t=0\n",
        "spk_out = torch.zeros(1)  # neuron needs somewhere to sequentially dump its output spikes\n",
        "mem_rec = [mem]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PW8ZDbKCNW8E"
      },
      "source": [
        "This time, the new values of `cur_in` are passed to the neuron:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w9J2AsrZNXr8"
      },
      "outputs": [],
      "source": [
        "num_steps = 200\n",
        "\n",
        "# pass updated value of mem and cur_in[step] at every time step\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in[step], mem)\n",
        "  mem_rec.append(mem)\n",
        "\n",
        "# crunch -list- of tensors into one tensor\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "\n",
        "plot_step_current_response(cur_in, mem_rec, 10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uZi44ZEvNhkR"
      },
      "source": [
        "As $t\\rightarrow \\infty$, the membrane potential $U_{\\rm mem}$ exponentially relaxes to $I_{\\rm in}R$:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1l40Efnbk92o"
      },
      "outputs": [],
      "source": [
        "print(f\"The calculated value of input pulse [A] x resistance [Ω] is: {cur_in[11]*lif1.R} V\")\n",
        "print(f\"The simulated value of steady-state membrane potential is: {mem_rec[200][0]} V\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H0ql9cAzpN5D"
      },
      "source": [
        "Close enough!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HOgCeLPYaLKZ"
      },
      "source": [
        "## 3.3 Lapicque: Pulse Input\n",
        "\n",
        "Now what if the step input was clipped at $t=30ms$?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hjtupQNxaQ5D"
      },
      "outputs": [],
      "source": [
        "# Initialize current pulse, membrane and outputs\n",
        "cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0)  # input turns on at t=10, off at t=30\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec1 = [mem]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O8JFL9CGa0NW"
      },
      "outputs": [],
      "source": [
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in1[step], mem)\n",
        "  mem_rec1.append(mem)\n",
        "mem_rec1 = torch.stack(mem_rec1)\n",
        "\n",
        "plot_current_pulse_response(cur_in1, mem_rec1, \"Lapicque's Neuron Model With Input Pulse\", \n",
        "                            vline1=10, vline2=30)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "orJvjev6a9tR"
      },
      "source": [
        "$U_{\\rm mem}$ rises just as it did for the step input, but now it decays with a time constant of $\\tau$ as in our first simulation. \n",
        "\n",
        "Let's drive $U_{\\rm mem}$ to approximately the same peak in half the time. This means the input current amplitude must be increased a little, and the time window must be decreased. Note that the charge delivered, $Q = I \\times t$, is now smaller ($0.111 \\times 10 < 0.1 \\times 20$), because less of it has time to leak away before the pulse ends."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j_G_gWTzbo4J"
      },
      "outputs": [],
      "source": [
        "# Increase amplitude of current pulse; half the time.\n",
        "cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0)  # input turns on at t=10, off at t=20\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec2 = [mem]\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in2[step], mem)\n",
        "  mem_rec2.append(mem)\n",
        "mem_rec2 = torch.stack(mem_rec2)\n",
        "\n",
        "plot_current_pulse_response(cur_in2, mem_rec2, \"Lapicque's Neuron Model With Input Pulse: x1/2 pulse width\",\n",
        "                            vline1=10, vline2=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "34azxvjWbqge"
      },
      "source": [
        "Let's do that again, but with an even faster input pulse and higher amplitude:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "650G94v9dgr3"
      },
      "outputs": [],
      "source": [
        "# Increase amplitude of current pulse; quarter the time.\n",
        "cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0)  # input turns on at t=10, off at t=15\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec3 = [mem]\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in3[step], mem)\n",
        "  mem_rec3.append(mem)\n",
        "mem_rec3 = torch.stack(mem_rec3)\n",
        "\n",
        "plot_current_pulse_response(cur_in3, mem_rec3, \"Lapicque's Neuron Model With Input Pulse: x1/4 pulse width\",\n",
        "                            vline1=10, vline2=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fuQtNgw4sYhk"
      },
      "source": [
        "Now compare all three experiments on the same plot:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lZ8iBAjVM9FO"
      },
      "outputs": [],
      "source": [
        "compare_plots(cur_in1, cur_in2, cur_in3, mem_rec1, mem_rec2, mem_rec3, 10, 15, \n",
        "              20, 30, \"Lapicque's Neuron Model With Input Pulse: Varying inputs\")"
      ]
    },
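    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three pulses can also be compared with a little arithmetic. This sketch takes each time step to be $1ms$ and treats the membrane as the continuous-time RC circuit $\\tau \\frac{dU_{\\rm mem}}{dt} = -U_{\\rm mem} + I_{\\rm in}R$ starting from rest. The symbol $T_W$ denotes the pulse width, and $Q_1$, $Q_2$, $Q_3$ are labels introduced here for the charge delivered by each pulse.\n",
        "\n",
        "$$Q_1 = 0.1A \\times 20ms = 2mC, \\quad Q_2 = 0.111A \\times 10ms = 1.11mC, \\quad Q_3 = 0.147A \\times 5ms = 0.735mC$$\n",
        "\n",
        "$$U_{\\rm mem}(T_W) = I_{\\rm in}R\\Big(1 - e^{-T_W/\\tau}\\Big)$$\n",
        "\n",
        "The second expression is the membrane potential at the moment the pulse switches off. A shorter pulse needs less total charge to reach a given peak, because less charge leaks away through $R$ while the pulse is still on."
      ]
    },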
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6WpZx0NStNdd"
      },
      "source": [
        "As the input current pulse amplitude increases, the rise time of the membrane potential shortens. In the limit of the input current pulse width becoming infinitesimally small, $T_W \\rightarrow 0s$, the membrane potential will jump straight up in virtually zero rise time:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h0EG17NpuzuN"
      },
      "outputs": [],
      "source": [
        "# Current spike input\n",
        "cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0)  # input only on for 1 time step\n",
        "mem = torch.zeros(1) \n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec4 = [mem]\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif1(cur_in4[step], mem)\n",
        "  mem_rec4.append(mem)\n",
        "mem_rec4 = torch.stack(mem_rec4)\n",
        "\n",
        "plot_current_pulse_response(cur_in4, mem_rec4, \"Lapicque's Neuron Model With Input Spike\", \n",
        "                            vline1=10, ylim_max1=0.6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DNY_EftJvREj"
      },
      "source": [
        "The current pulse width is now so short that it effectively looks like a spike. That is to say, charge is delivered in an infinitely short period of time, $I_{\\rm in}(t) = Q/T_W$ where $T_W \\rightarrow 0$. More formally:\n",
        "\n",
        "$$I_{\\rm in}(t) = Q \\delta (t-t_0),$$\n",
        "\n",
        "where $\\delta (t-t_0)$ is the Dirac delta function and $t_0$ is the time at which the spike arrives. Physically, it is impossible to 'instantaneously' deposit charge. But integrating $I_{\\rm in}$ gives a result that makes physical sense, as we can obtain the charge delivered:\n",
        "\n",
        "$$1 = \\int^{t_0 + a}_{t_0 - a}\\delta(t-t_0)dt$$\n",
        "\n",
        "$$f(t_0) = \\int^{t_0 + a}_{t_0 - a}f(t)\\delta(t-t_0)dt$$\n",
        "\n",
        "Here, $f(t) = Q$, so integrating $I_{\\rm in}(t)$ across the spike at $t_0 = 10ms$ returns the charge delivered. In the simulation, $0.5A$ is held for a single $1ms$ time step, so $Q = 0.5A \\times 1ms = 0.5mC$.\n"
      ]
    },
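    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of such a current spike on the membrane can be worked out by integrating the passive membrane equation across the spike. This is a sketch that assumes the continuous-time model $\\tau \\frac{dU_{\\rm mem}}{dt} = -U_{\\rm mem} + I_{\\rm in}R$ with $\\tau = RC$. The symbol $\\Delta U_{\\rm mem}$ is introduced here for the change in membrane potential across the spike.\n",
        "\n",
        "$$\\tau \\int^{t_0 + a}_{t_0 - a}\\frac{dU_{\\rm mem}}{dt}dt = -\\int^{t_0 + a}_{t_0 - a}U_{\\rm mem}dt + R\\int^{t_0 + a}_{t_0 - a}Q\\delta(t-t_0)dt$$\n",
        "\n",
        "As $a \\rightarrow 0$, the integral of $U_{\\rm mem}$ vanishes because the membrane potential stays finite, which leaves:\n",
        "\n",
        "$$\\tau \\Delta U_{\\rm mem} = RQ \\implies \\Delta U_{\\rm mem} = \\frac{RQ}{RC} = \\frac{Q}{C}$$\n",
        "\n",
        "So an idealized current spike makes the membrane potential jump by $Q/C$ with zero rise time, after which it decays with time constant $\\tau$."
      ]
    },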
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5kQ6MteCpMnk"
      },
      "source": [
        "Hopefully you have a good feel for how the membrane potential leaks at rest and integrates the input current. That covers the 'leaky' and 'integrate' parts of the neuron. How about the 'fire'? "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CUH8IYh9xlV-"
      },
      "source": [
        "## 3.4 Lapicque: Firing\n",
        "\n",
        "So far, we have only seen how a neuron will react to spikes at the input. For a neuron to generate and emit its own spikes at the output, the passive membrane model must be combined with a threshold.\n",
        "\n",
        "If the membrane potential exceeds this threshold, then a voltage spike will be generated, external to the passive membrane model. \n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_4_spiking.png?raw=true' width=\"450\">\n",
        "</center>\n",
        "\n",
        "Modify the `leaky_integrate_neuron` function from before to add a spike response."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zNCfZ6dQkuwj"
      },
      "outputs": [],
      "source": [
        "# R=5.1, C=5e-3 for illustrative purposes\n",
        "def leaky_integrate_and_fire(mem, cur=0, threshold=1, time_step=1e-3, R=5.1, C=5e-3):\n",
        "  tau_mem = R*C\n",
        "  spk = (mem > threshold) # if membrane exceeds threshold, spk=1, else, 0\n",
        "  mem = mem + (time_step/tau_mem)*(-mem + cur*R)\n",
        "  return mem, spk"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g3ZazVZkljmF"
      },
      "source": [
        "Set `threshold=1`, and apply a step current to get this neuron spiking."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "19gpai24k9Ex"
      },
      "outputs": [],
      "source": [
        "# Small step current input\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n",
        "mem = torch.zeros(1)\n",
        "mem_rec = []\n",
        "spk_rec = []\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  mem, spk = leaky_integrate_and_fire(mem, cur_in[step])\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, vline=109, ylim_max2=1.3, \n",
        "                 title=\"LIF Neuron Model With Uncontrolled Spiking\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QEBX8ICtn9tY"
      },
      "source": [
        "Oops - the output spikes have gone out of control! This is because we forgot to add a reset mechanism. In reality, each time a neuron fires, the membrane potential hyperpolarizes back to its resting potential.\n",
        "\n",
        "Implementing this reset mechanism into our neuron:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oj0MJmApoRlH"
      },
      "outputs": [],
      "source": [
        "# LIF w/Reset mechanism\n",
        "def leaky_integrate_and_fire(mem, cur=0, threshold=1, time_step=1e-3, R=5.1, C=5e-3):\n",
        "  tau_mem = R*C\n",
        "  spk = (mem > threshold)\n",
        "  mem = mem + (time_step/tau_mem)*(-mem + cur*R) - spk*threshold  # every time spk=1, subtract the threshold\n",
        "  return mem, spk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7VZkftxRoa3n"
      },
      "outputs": [],
      "source": [
        "# Small step current input\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n",
        "mem = torch.zeros(1)\n",
        "mem_rec = []\n",
        "spk_rec = []\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  mem, spk = leaky_integrate_and_fire(mem, cur_in[step])\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, vline=109, ylim_max2=1.3, \n",
        "                 title=\"LIF Neuron Model With Reset\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bf91G6E3lZFx"
      },
      "source": [
        "Bam. We now have a functional leaky integrate-and-fire neuron model! \n",
        "\n",
        "Note that if $I_{\\rm in}=0.2 A$ and $R<5 \\Omega$, then $I\\times R < 1 V$. If `threshold=1`, then no spiking would occur. Feel free to go back up, change the values, and test it out."
      ]
    },
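    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The time to the first spike can also be estimated by hand. The following is a sketch under the continuous-time approximation of the model above, so it agrees with the discrete-time simulation only approximately. The symbol $t_{\\rm on}$ is a name introduced here for the moment the current switches on, $t_{\\rm spk}$ is the first threshold crossing, and $U_{\\rm thr}$ is the threshold.\n",
        "\n",
        "Starting from rest, a constant current gives $U_{\\rm mem}(t) = I_{\\rm in}R\\big(1 - e^{-(t - t_{\\rm on})/\\tau}\\big)$. Setting this equal to $U_{\\rm thr}$ and solving for time:\n",
        "\n",
        "$$t_{\\rm spk} - t_{\\rm on} = \\tau \\ln\\Big(\\frac{I_{\\rm in}R}{I_{\\rm in}R - U_{\\rm thr}}\\Big), \\quad I_{\\rm in}R > U_{\\rm thr}$$\n",
        "\n",
        "With $R = 5.1\\Omega$, $C = 5mF$, $I_{\\rm in} = 0.2A$ and $U_{\\rm thr} = 1V$, we have $\\tau = 25.5ms$ and $I_{\\rm in}R = 1.02V$, so:\n",
        "\n",
        "$$t_{\\rm spk} - t_{\\rm on} = 25.5ms \\times \\ln\\Big(\\frac{1.02}{0.02}\\Big) \\approx 100ms$$\n",
        "\n",
        "With the current switched on at $t = 10ms$, the first spike is expected near $t \\approx 110ms$. If $I_{\\rm in}R \\leq U_{\\rm thr}$, the logarithm is undefined and the threshold is never crossed, which is the no-spiking case noted above."
      ]
    },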
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "InxmKoplUnoB"
      },
      "source": [
        "As before, all of that code is condensed by calling the built-in Lapicque neuron model from snnTorch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HPx-uuH90JTo"
      },
      "outputs": [],
      "source": [
        "# Create the same neuron as before using snnTorch\n",
        "lif2 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3)\n",
        "\n",
        "print(f\"Membrane potential time constant: {lif2.R * lif2.C:.3f}s\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LU9dSGi6zEOP"
      },
      "outputs": [],
      "source": [
        "# Initialize inputs and outputs\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1) \n",
        "mem_rec = [mem]\n",
        "spk_rec = [spk_out]\n",
        "\n",
        "# Simulation run across num_steps (200) time steps.\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif2(cur_in[step], mem)\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk_out)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, vline=109, ylim_max2=1.3, \n",
        "                 title=\"Lapicque Neuron Model With Step Input\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cm5tEp4YEoQm"
      },
      "source": [
        "The membrane potential exponentially rises and then hits the threshold, at which point it resets. We can roughly see this occurs between $105ms < t_{\\rm spk} < 115ms$. As a matter of curiosity, let's see what the spike recording actually consists of:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tWSGUUsgNr9D"
      },
      "outputs": [],
      "source": [
        "print(spk_rec[105:115].view(-1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4xWW-ncoOWrJ"
      },
      "source": [
        "The absence of a spike is represented by $S_{\\rm out}=0$, and the occurrence of a spike is $S_{\\rm out}=1$. Here, the spike occurs at $S_{\\rm out}[t=109]=1$. If you are wondering why each of these entries is stored as a tensor, it is because in future tutorials we will simulate large scale neural networks. Each entry will contain the spike responses of many neurons, and tensors can be loaded into GPU memory to speed up the training process.\n",
        "\n",
        "If $I_{\\rm in}$ is increased, then the membrane potential approaches the threshold $U_{\\rm thr}$ faster:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HY6G9HZ1POC7"
      },
      "outputs": [],
      "source": [
        "# Initialize inputs and outputs\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0)  # increased current\n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1) \n",
        "mem_rec = [mem]\n",
        "spk_rec = [spk_out]\n",
        "\n",
        "# neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif2(cur_in[step], mem)\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk_out)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, ylim_max2=1.3, \n",
        "                 title=\"Lapicque Neuron Model With Periodic Firing\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GOjWXepnPleo"
      },
      "source": [
        "A similar increase in firing frequency can also be induced by decreasing the threshold. This requires initializing a new neuron model, but the rest of the code block is the exact same as above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HQ0VIBlJPrvr"
      },
      "outputs": [],
      "source": [
        "# neuron with halved threshold\n",
        "lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)\n",
        "\n",
        "# Initialize inputs and outputs\n",
        "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) \n",
        "mem = torch.zeros(1)\n",
        "spk_out = torch.zeros(1) \n",
        "mem_rec = [mem]\n",
        "spk_rec = [spk_out]\n",
        "\n",
        "# Neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif3(cur_in[step], mem)\n",
        "  mem_rec.append(mem)\n",
        "  spk_rec.append(spk_out)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=0.5, ylim_max2=1.3, \n",
        "                 title=\"Lapicque Neuron Model With Lower Threshold\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rJwtafQ_Siav"
      },
      "source": [
        "That's what happens for a constant current injection. But in both deep neural networks and in the biological brain, most neurons will be connected to other neurons. They are more likely to receive spikes, rather than injections of constant current. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sOr9akj9S2Ei"
      },
      "source": [
        "## 3.5 Lapicque: Spike Inputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rw_FZDnMS-pc"
      },
      "source": [
        "Let's harness some of the skills we learnt in [Tutorial 1](https://colab.research.google.com/github/jeshraghian/snntorch/blob/tutorials/examples/tutorial_1_spikegen.ipynb), and use the `snntorch.spikegen` module to create some randomly generated input spikes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V7AXG4QqTCoZ"
      },
      "outputs": [],
      "source": [
        "# Create a 1-D random spike train. Each element has a probability of 40% of firing.\n",
        "spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)"
      ]
    },
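    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check, the size of this spike train can be estimated in advance. This sketch assumes each time step is an independent Bernoulli trial with firing probability $p = 0.4$, as the code comment describes. The symbol $N$ is a name introduced here for the total spike count over `num_steps = 200` steps.\n",
        "\n",
        "$$E[N] = 200 \\times 0.4 = 80, \\quad \\sigma_N = \\sqrt{200 \\times 0.4 \\times 0.6} \\approx 6.9$$\n",
        "\n",
        "A count that differs from 80 by ten or so is therefore unsurprising, and the exact number will change every time the cell is re-run."
      ]
    },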
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5CJxZg1TTNCJ"
      },
      "source": [
        "Run the following code block to see how many spikes have been generated."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rbhLmpNSTPgK"
      },
      "outputs": [],
      "source": [
        "print(f\"There are {int(sum(spk_in))} total spikes out of {len(spk_in)} time steps.\")\n",
        "\n",
        "fig = plt.figure(facecolor=\"w\", figsize=(8, 1))\n",
        "ax = fig.add_subplot(111)\n",
        "\n",
        "splt.raster(spk_in.reshape(num_steps, -1), ax, s=100, c=\"black\", marker=\"|\")\n",
        "plt.title(\"Input Spikes\")\n",
        "plt.xlabel(\"Time step\")\n",
        "plt.yticks([])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wo2_Voc3TvSc"
      },
      "outputs": [],
      "source": [
        "# Initialize inputs and outputs\n",
        "mem = torch.ones(1)*0.5\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec = [mem]\n",
        "spk_rec = [spk_out]\n",
        "\n",
        "# Neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif3(spk_in[step], mem)\n",
        "  spk_rec.append(spk_out)\n",
        "  mem_rec.append(mem)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec = torch.stack(mem_rec)\n",
        "spk_rec = torch.stack(spk_rec)\n",
        "\n",
        "# plot the full spike recording (spk_rec), not just the final-step output (spk_out)\n",
        "plot_spk_mem_spk(spk_in, mem_rec, spk_rec, \"Lapicque's Neuron Model With Input Spikes\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lQsH8VOGXF2Q"
      },
      "source": [
        "## 3.6 Lapicque: Reset Mechanisms\n",
        "We already implemented a reset mechanism from scratch, but let's dive a little deeper. This sharp drop of membrane potential promotes a reduction of spike generation, which supplements part of the theory on how brains are so power efficient. Biologically, this drop of membrane potential is known as 'hyperpolarization'. Following that, it is momentarily more difficult to elicit another spike from the neuron. Here, we use a reset mechanism to model hyperpolarization.\n",
        "\n",
        "There are three options for the reset mechanism:\n",
        "\n",
        "1.  *reset by subtraction* (default) $-$ subtract the threshold from the membrane potential each time a spike is generated;\n",
        "2.  *reset to zero* $-$ force the membrane potential to zero each time a spike is generated;\n",
        "3.  *no reset* $-$ do nothing, and let the firing go potentially uncontrolled.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial2/2_5_reset.png?raw=true' width=\"450\">\n",
        "</center>\n",
        "\n",
        "Instantiate another neuron model to demonstrate how to alternate between reset mechanisms. \n",
        "By default, snnTorch neuron models use `reset_mechanism = \"subtract\"`. This can be explicitly overridden by passing the argument `reset_mechanism = \"zero\"`."
      ]
    },
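    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three options can be written side by side. This is a sketch in the notation of the from-scratch function, not a transcription of the snnTorch source. Here $S[t]$ is 1 when $U[t] > U_{\\rm thr}$ and 0 otherwise, and $\\tilde{U}[t+1] = U[t] + \\frac{\\Delta t}{\\tau}(-U[t] + I_{\\rm in}[t]R)$ is shorthand introduced here for the update without any reset.\n",
        "\n",
        "$$\\text{subtract:} \\quad U[t+1] = \\tilde{U}[t+1] - S[t]U_{\\rm thr}$$\n",
        "\n",
        "$$\\text{zero:} \\quad U[t+1] = (1 - S[t])\\tilde{U}[t+1]$$\n",
        "\n",
        "$$\\text{none:} \\quad U[t+1] = \\tilde{U}[t+1]$$\n",
        "\n",
        "For example, if $\\tilde{U}[t+1] = 0.62$ on a spiking step with $U_{\\rm thr} = 0.5$, reset by subtraction leaves $0.12$ on the membrane, whereas reset to zero leaves $0$."
      ]
    },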
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7bfidqhKXdri"
      },
      "outputs": [],
      "source": [
        "# Neuron with reset_mechanism set to \"zero\"\n",
        "lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism=\"zero\")\n",
        "\n",
        "# Initialize inputs and outputs\n",
        "# spk_in is reused from Section 3.5 so that both neurons are driven by the same spike train\n",
        "mem = torch.ones(1)*0.5\n",
        "spk_out = torch.zeros(1)\n",
        "mem_rec0 = [mem]\n",
        "spk_rec0 = [spk_out]\n",
        "\n",
        "# Neuron simulation\n",
        "for step in range(num_steps):\n",
        "  spk_out, mem = lif4(spk_in[step], mem)\n",
        "  spk_rec0.append(spk_out)\n",
        "  mem_rec0.append(mem)\n",
        "\n",
        "# convert lists to tensors\n",
        "mem_rec0 = torch.stack(mem_rec0)\n",
        "spk_rec0 = torch.stack(spk_rec0)\n",
        "\n",
        "plot_reset_comparison(spk_in, mem_rec, spk_rec, mem_rec0, spk_rec0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dzs7qX2Kb8j_"
      },
      "source": [
        "Pay close attention to the evolution of the membrane potential, especially in the moments after it reaches the threshold. You may notice that for \"Reset to Zero\", the membrane potential is forced back to zero after each spike.\n",
        "\n",
        "So which one is better? Applying `\"subtract\"` (the default value in `reset_mechanism`) is less lossy, because it does not discard the amount by which the membrane potential exceeds the threshold.\n",
        "\n",
        "On the other hand, applying a hard reset with `\"zero\"` promotes sparsity and potentially less power consumption when running on dedicated neuromorphic hardware. Both options are available for you to experiment with. \n",
        "\n",
        "That covers the basics of a LIF neuron model!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PYWhh7idiS9t"
      },
      "source": [
        "# Conclusion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BGCY7g7R0Usy"
      },
      "source": [
        "In practice, we probably wouldn't use this neuron model to train a neural network. The Lapicque LIF model has added a lot of hyperparameters to tune: $R$, $C$, $\\Delta t$, $U_{\\rm thr}$, and the choice of reset mechanism. It's all a little bit daunting. So the [next tutorial](https://snntorch.readthedocs.io/en/latest/tutorials/index.html) will eliminate most of these hyperparameters, and introduce a neuron model that is better suited for large-scale deep learning. \n",
        "\n",
        "For reference, the documentation [can be found here](https://snntorch.readthedocs.io/en/latest/snntorch.html).\n",
        "\n",
        "If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OTRCyd0Xa-QK"
      },
      "source": [
        "## Further Reading\n",
        "* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)\n",
        "* [snnTorch documentation](https://snntorch.readthedocs.io/en/latest/snntorch.html) of the Lapicque, Leaky, Synaptic, and Alpha models\n",
        "* [*Neuronal Dynamics:\n",
        "From single neurons to networks and models of cognition*](https://neuronaldynamics.epfl.ch/index.html) by\n",
        "Wulfram Gerstner, Werner M. Kistler, Richard Naud and Liam Paninski.\n",
        "* [Theoretical Neuroscience: Computational and Mathematical Modeling of Neural Systems](https://mitpress.mit.edu/books/theoretical-neuroscience) by Laurence F. Abbott and Peter Dayan"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "include_colab_link": true,
      "name": "Untitled17.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
